{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AJJOEf6c7I-_"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gLeUAXP5j8RJ"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "#tf.config.experimental.set_visible_devices([], \"GPU\")\n",
        "\n",
        "import importlib\n",
        "from simulation_research.diffusion import ode_datasets\n",
        "from simulation_research.diffusion import diffusion_unet\n",
        "from simulation_research.diffusion import samplers\n",
        "from simulation_research.diffusion import diffusion as train\n",
        "importlib.reload(ode_datasets)\n",
        "importlib.reload(diffusion_unet)\n",
        "importlib.reload(samplers)\n",
        "importlib.reload(train)\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib import rc\n",
        "rc('animation', html='jshtml')\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import jax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RQwCeAGHkB3i"
      },
      "outputs": [],
      "source": [
        "# Build the Lorenz dataset. The first `bs` trajectories are held out as the test set and\n",
        "# the remaining ones are used for training; only the first 60 time steps are kept.\n",
        "# Assumes ds.Zs is indexed (trajectory, time, state) -- consistent with T_long[:60] below.\n",
        "dt = .1\n",
        "bs = 400\n",
        "#ds = ode_datasets.NPendulum(N=4000+bs,n=2,dt=dt)\n",
        "ds = ode_datasets.LorenzDataset(N=4000+bs,dt=dt,integration_time=7)\n",
        "\n",
        "thetas  = ds.Zs[bs:,:60]\n",
        "test_x = ds.Zs[:bs,:60]\n",
        "T_long =ds.T_long[:60]\n",
        "#thetas /=thetas.std()\n",
        "#thetas = jax.random.normal(jax.random.PRNGKey(38),thetas.shape)\n",
        "dataset = tf.data.Dataset.from_tensor_slices(thetas)\n",
        "\n",
        "# NOTE: `as_numpy_iterator` is passed uncalled, so each `dataiter()` call builds a fresh\n",
        "# iterator over a reshuffled set of batches (presumably once per epoch in train_diffusion).\n",
        "dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uzMJBTiHDeuH"
      },
      "outputs": [],
      "source": [
        "from jax import jit,vmap\n",
        "\n",
        "@jit\n",
        "def rel_err(x,y):\n",
        "  \"\"\"Symmetric relative L1 error |x-y| / (|x| + |y|), reduced over the last axis.\"\"\"\n",
        "  return  jnp.abs(x-y).sum(-1)/(jnp.abs(x).sum(-1)+jnp.abs(y).sum(-1))\n",
        "\n",
        "\n",
        "# Number of leading time steps skipped before a sample is compared with its rollout.\n",
        "kstart=10\n",
        "\n",
        "@jit\n",
        "def log_prediction_metric(qs):\n",
        "  \"\"\"Mean log relative error of one sampled trajectory against its ODE rollout.\n",
        "\n",
        "  The sampled state at index `kstart` is integrated with `ds.integrate` over\n",
        "  `T_long[kstart:]`; the log relative error is averaged over the first third of that\n",
        "  horizon, excluding the shared initial point.\n",
        "  \"\"\"\n",
        "  z = qs[kstart:]\n",
        "  T = T_long[kstart:]\n",
        "  z_gt = ds.integrate(z[0],T)\n",
        "  return jnp.log(rel_err(z,z_gt)[1:len(T)//3]).mean()\n",
        "\n",
        "@jit\n",
        "def pmetric(qs):\n",
        "  \"\"\"Geometric-mean prediction error over a batch of trajectories.\n",
        "\n",
        "  Returns (exp(mean of log errors), exp(standard error of the mean log error)); the\n",
        "  second value is a multiplicative error-bar factor.\n",
        "  \"\"\"\n",
        "  log_metric = vmap(log_prediction_metric)(qs)\n",
        "  return jnp.exp(log_metric.mean()),jnp.exp(log_metric.std()/jnp.sqrt(log_metric.shape[0]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JbFiGz6W8Qgc"
      },
      "outputs": [],
      "source": [
        "# Held-out test trajectories; later cells use `x` both as data and as a shape template.\n",
        "x = test_x#next(dataiter())\n",
        "# NOTE(review): `t` is not read by any later cell in this notebook.\n",
        "t = np.random.rand(x.shape[0])\n",
        "# Score network whose output dimension matches the state dimension (last axis of `x`).\n",
        "model = diffusion_unet.UNet(diffusion_unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nUUCduoEkaYh"
      },
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "# Colab form fields select the noise covariance and the SDE family.\n",
        "noisetype='White'#@param ['White','Pink','Brown']\n",
        "noise = {'White':train.Identity,'Pink':train.PinkCovariance,'Brown':train.BrownianCovariance}[noisetype]\n",
        "difftype='VE'#@param ['VP','VE','SubVP','Test']\n",
        "diff = {'VP':train.VariancePreserving,'VE':train.VarianceExploding,\n",
        "        'SubVP':train.SubVariancePreserving,'Test':train.Test}[difftype](noise)\n",
        "epochs = 2000#@param {'type':'integer'}\n",
        "ic_conditioning = False#@param {'type':'boolean'}\n",
        "# Train the score model; the returned `score_fn(xt, t)` is used by all sampling cells below.\n",
        "score_fn = train.train_diffusion(model,dataiter,epochs,diffusion=diff,lr=3e-4,\n",
        "                                 ic_conditioning=ic_conditioning)\n",
        "key= jax.random.PRNGKey(38)\n",
        "# With IC conditioning, the score function also receives the first 3 steps of each test trajectory.\n",
        "cond =test_x[:,:3]\n",
        "eval_scorefn = partial(score_fn,cond=cond) if ic_conditioning else score_fn\n",
        "# Evaluate: mean test NLL and the rollout metric of 400 unconditional SDE samples.\n",
        "nll = samplers.compute_nll(diff,eval_scorefn,key,test_x).mean()\n",
        "stoch_samples = samplers.stochastic_sample(diff,eval_scorefn,key,test_x.shape,N=1000,traj=False)\n",
        "err = pmetric(stoch_samples)[0]\n",
        "print(f\"{noise.__name__} gets NLL {nll:.3f} and err {err:.3f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P1VdapJf4Omo"
      },
      "outputs": [],
      "source": [
        "# Draws from the selected noise process, for visual comparison with the data.\n",
        "samples = noise.sample(jax.random.PRNGKey(39),x.shape)\n",
        "# Reference signal: the first half of each test trajectory followed by its time reversal.\n",
        "# Its length is 2*(n//2), which matches x only when the number of time steps is even (60 here).\n",
        "half_x = x[:,:x.shape[1]//2]\n",
        "samples2 = jnp.concatenate([half_x,half_x[:,::-1]],axis=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TY5VdTAx4Xlz"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "# Overlay the first state component of noise draw `i` and of the reference signal.\n",
        "# NOTE(review): the label 'brown1' is used regardless of the selected noisetype.\n",
        "i=15 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "plt.plot(T_long,samples[i,:,0].T,alpha=1/2,label='brown1')\n",
        "plt.plot(T_long,samples2[i,:,0].T,alpha=1/2,label='data')\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zJk34SSU5r64"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "# NOTE(review): the slider value `i` is not used in this cell.\n",
        "i=5 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "# Magnitude spectra (rFFT along time) of every 25th noise draw and reference signal,\n",
        "# last state component, on log-log axes.\n",
        "fourier_mag1 = jnp.abs(jnp.fft.rfft(samples[::25,:,-1],axis=1))\n",
        "fourier_mag2 = jnp.abs(jnp.fft.rfft(samples2[::25,:,-1],axis=1))\n",
        "# NOTE(review): every line carries the same label, so the legend repeats each entry.\n",
        "plt.plot(fourier_mag1.T,alpha=1/5,label='brown1',color='y')\n",
        "plt.plot(fourier_mag2.T,alpha=1/5,label='data',color='brown')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "# plt.xlabel('Time t')\n",
        "# plt.ylabel(r'State')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fRn9pR8Ewzey"
      },
      "outputs": [],
      "source": [
        "# Pick up edits to the sampler / training modules without restarting the kernel.\n",
        "# NOTE(review): objects created before the reload (e.g. `diff`, `noise`) keep the old\n",
        "# class definitions, so isinstance checks against the reloaded module may fail.\n",
        "importlib.reload(samplers)\n",
        "importlib.reload(train)\n",
        "#samplers.probability_flow(diff,score_fn,x,1e-4,1.).std()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y-P8uVpcWP-x"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "# Recompute the mean test-set NLL.\n",
        "# NOTE(review): uses `score_fn` directly, so with ic_conditioning=True the `cond` argument\n",
        "# is missing (the training cell uses `eval_scorefn` instead).\n",
        "key= jax.random.PRNGKey(38)\n",
        "samplers.compute_nll(diff,score_fn,key,x).mean()"
      ]
    },
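    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8nc8LmPLL_z4"
      },
      "source": [
        "## Sample generation\n",
        "\n",
        "Draw unconditional samples from the trained model with the stochastic (SDE) sampler, keeping its full denoising trajectory, and with the deterministic (ODE) sampler."
      ]
    },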
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kboc2IIeteh7"
      },
      "outputs": [],
      "source": [
        "# 30 samples each: final SDE samples, the full SDE sampling trajectory (traj=True), and ODE samples.\n",
        "# NOTE: this overwrites the 400-sample `stoch_samples` produced in the training cell.\n",
        "# NOTE(review): the same `key` is reused, so the two SDE calls presumably share noise -- confirm.\n",
        "stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=False)\n",
        "sample_traj = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=True)\n",
        "det_samples = samplers.sample(diff,score_fn,key,x[:30].shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5rrpbT2MAWRn"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "# Every 100th intermediate state of the SDE sampling trajectory for sample `i`\n",
        "# (first state component). Only 30 samples exist, so valid values are 0-29.\n",
        "i=6 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "plt.plot(T_long,sample_traj[0::100,i,:,0].T,alpha=1/2)\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "#plt.ylim(-5,5)\n",
        "#plt.legend([r'GT',r'Model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-LerwGMP9Bla"
      },
      "outputs": [],
      "source": [
        "from jax import vmap\n",
        "\n",
        "# Diffusion time for each saved step of `sample_traj`: midpoints (k + 1/2)/n for\n",
        "# k = n-1, ..., 1, so ts decreases from ~1 towards 0 along the trajectory.\n",
        "# Assumes the sampler saves one state per step on a uniform grid -- confirm.\n",
        "n=sample_traj.shape[0]+1\n",
        "ts = (jnp.arange(n-1,0,-1)+.5)/n\n",
        "\n",
        "# Score of every intermediate state, evaluated at its diffusion time.\n",
        "scores = vmap(score_fn)(sample_traj,ts).reshape(sample_traj.shape)\n",
        "\n",
        "# One-step denoised estimate (x_t + sigma(t)^2 * score) / scale(t), with the per-step\n",
        "# schedule coefficients broadcast over the (batch, time, state) axes.\n",
        "best_reconstructions = (sample_traj+scores*diff.sigma(ts)[:,None,None,None]**2)/diff.scale(ts)[:,None,None,None]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xr1eZyO-1IWB"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import matplotlib as mpl\n",
        "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
        "\n",
        "# Evolution of the one-step denoised estimate of sample `i` (last state component) along\n",
        "# the sampling trajectory: every 25th saved step from step 100 on, coloured from dark\n",
        "# (early, large diffusion time) to bright (late, small diffusion time).\n",
        "i=4 #@param {type:\"slider\", min:0, max:29, step:1}\n",
        "\n",
        "cmap='inferno'\n",
        "\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = best_reconstructions[100::25,i,:,-1].T\n",
        "ax1.plot(T_long,data,alpha=.6,lw=2)\n",
        "# `plt.get_cmap` works on all matplotlib versions; `mpl.cm.get_cmap` was removed in 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "# zip instead of enumerate so the slider value `i` is not overwritten.\n",
        "for line,color in zip(ax1.lines,colors):\n",
        "    line.set_color(color)\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "# Colourbar keyed to diffusion time. The reversed colormap matches the line colouring\n",
        "# because diffusion time decreases along the trajectory; bounds are passed in increasing order.\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ohSLTVT9f2M"
      },
      "outputs": [],
      "source": [
        "from scipy.ndimage import correlate1d\n",
        "i=22 #@param {type:\"slider\", min:0, max:29, step:1}\n",
        "# Central-difference time derivative of the denoised estimates along the time axis (axis 2).\n",
        "# correlate1d with weights [-1, 0, 1] / (2 dt) already yields (x[k+1] - x[k-1]) / (2 dt) = +dx/dt,\n",
        "# so no extra sign flip is applied. Endpoints use the default 'reflect' boundary mode.\n",
        "vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
        "print(vs.shape)\n",
        "# Relies on `mpl`, `cmap` and `make_axes_locatable` from the previous plotting cell.\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = vs[100::25,i,:,-1].T\n",
        "ax1.plot(T_long,data,alpha=.6,lw=2)\n",
        "# `plt.get_cmap` works on all matplotlib versions; `mpl.cm.get_cmap` was removed in 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "for line,color in zip(ax1.lines,colors):\n",
        "    line.set_color(color)\n",
        "plt.xlabel('Time t')\n",
        "# NOTE(review): the theta label is a leftover from the pendulum dataset; this plots the\n",
        "# time derivative of the last state component.\n",
        "plt.ylabel(r'$\\dot \\theta$')\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fiq7MGQm-1x-"
      },
      "outputs": [],
      "source": [
        "i=15 # @param {type:\"slider\", min:0, max:29, step:1}\n",
        "# Magnitude spectrum (rFFT along time) of every 25th state of the SDE sampling trajectory\n",
        "# for sample `i` (last state component), coloured by diffusion time.\n",
        "fft = jnp.abs(np.fft.rfft(sample_traj,axis=2))\n",
        "freq = np.fft.rfftfreq(sample_traj.shape[2],d=(ds.T[1]-ds.T[0]))\n",
        "\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = fft[0::25,i,:,-1].T\n",
        "ax1.plot(freq,data,alpha=.6,lw=2)\n",
        "# `plt.get_cmap` works on all matplotlib versions; `mpl.cm.get_cmap` was removed in 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "for line,color in zip(ax1.lines,colors):\n",
        "    line.set_color(color)\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[0])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "# Overlay spectra of every 10th test trajectory (blue) as the data reference.\n",
        "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GSCNVbsV-zlI"
      },
      "outputs": [],
      "source": [
        "i=8 # @param {type:\"slider\", min:0, max:29, step:1}\n",
        "# Magnitude spectrum (rFFT along time) of the one-step denoised estimates of sample `i`\n",
        "# (last state component), every 25th step from step 100 on, coloured by diffusion time.\n",
        "fft = jnp.abs(np.fft.rfft(best_reconstructions,axis=2))\n",
        "freq = np.fft.rfftfreq(best_reconstructions.shape[2],d=(ds.T[1]-ds.T[0]))\n",
        "\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = fft[100::25,i,:,-1].T\n",
        "ax1.plot(freq,data,alpha=.6,lw=2)\n",
        "# `plt.get_cmap` works on all matplotlib versions; `mpl.cm.get_cmap` was removed in 3.9.\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "for line,color in zip(ax1.lines,colors):\n",
        "    line.set_color(color)\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "# Overlay spectra of every 10th test trajectory (blue) as the data reference.\n",
        "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fa5_fj-mU_zw"
      },
      "outputs": [],
      "source": [
        "\n",
        "import matplotlib.pyplot as plt\n",
        "# Overlay (last state component) test trajectory `i` with ODE and SDE sample `i`.\n",
        "# NOTE(review): the samples are unconditional, so row i of `x` and of the samples are not\n",
        "# paired; 'GT' is a reference data trajectory. Only 30 samples exist (valid i: 0-29).\n",
        "i=10 # @param {type:\"slider\", min:0, max:30, step:1}\n",
        "plt.plot(T_long,x[i,:,-1])\n",
        "plt.plot(T_long,det_samples[i,:,-1])\n",
        "plt.plot(T_long,stoch_samples[i,:,-1])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'GT',r'Model (ODE)', r'Model (SDE)'])"
      ]
    },
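    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QFHO0tDBL42D"
      },
      "source": [
        "## Conditioning on observed time steps\n",
        "\n",
        "Test the model's ability to complete a trajectory when its first `condition_amount` time steps are observed (inpainting)."
      ]
    },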
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cH8P02jcfKRH"
      },
      "outputs": [],
      "source": [
        "from jax import grad,jit\n",
        "condition_amount = 10# @param {type:\"slider\", min:0, max:50, step:1}\n",
        "mb = x[:30,:]\n",
        "data_std = x.std()\n",
        "\n",
        "def inpainting_scores(diffusion,scorefn,observed_values,slc):\n",
        "  \"\"\"Score function conditioned on observed time steps by replacement.\n",
        "\n",
        "  Args:\n",
        "    diffusion: diffusion object providing `noise_score`.\n",
        "    scorefn: unconditional score function `scorefn(xt, t)`.\n",
        "    observed_values: observed states on the time steps `slc`, shape (b, n_obs, c).\n",
        "    slc: slice of time steps (axis 1) that are observed.\n",
        "\n",
        "  Returns:\n",
        "    conditioned_scores(xt, t), giving a (b, n_time, c) score where the observed steps\n",
        "    use `diffusion.noise_score` towards `observed_values` and the rest use the model score.\n",
        "  \"\"\"\n",
        "  b,_,c = observed_values.shape\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    return unobserved_score.at[:,slc].set(observed_score)\n",
        "  return conditioned_scores\n",
        "\n",
        "def inpainting_scores2(diffusion,scorefn,observed_values,slc,scale=300.):\n",
        "  \"\"\"Like `inpainting_scores`, plus gradient guidance on the unobserved steps.\n",
        "\n",
        "  The model score is corrected by the gradient of the squared mismatch between the\n",
        "  one-step denoised estimate (xt + sigma(t)^2 * score) / scale(t) on the observed steps\n",
        "  and `observed_values`, weighted by `scale * scale(t)^2 / sigma(t)^2`. All schedule\n",
        "  quantities come from the `diffusion` argument. Returns a jitted conditioned_scores(xt, t).\n",
        "  \"\"\"\n",
        "  b,_,c = observed_values.shape\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    def constraint(x_in):\n",
        "      one_step_xhat = (x_in+diffusion.sigma(t)**2*scorefn(x_in,t))/diffusion.scale(t)\n",
        "      return jnp.sum((one_step_xhat.reshape(b,-1,c)[:,slc]-observed_values)**2)\n",
        "    guidance = grad(constraint)(xt).reshape(unflat_xt.shape)\n",
        "    unobserved_score -= guidance*scale*diffusion.scale(t)**2/diffusion.sigma(t)**2\n",
        "    return unobserved_score.at[:,slc].set(observed_score)\n",
        "  return jit(conditioned_scores)\n",
        "\n",
        "slc = slice(condition_amount)\n",
        "conditioned_samples = samplers.stochastic_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape,N=1000,traj=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kWWztvtoUljD"
      },
      "outputs": [],
      "source": [
        "# Tile the 30 conditioning trajectories 10 times (copy-major: shape (10*30, T, C)) and draw\n",
        "# one conditioned completion per row, i.e. 10 completions per trajectory (presumably with\n",
        "# independent sampler noise per row -- confirm).\n",
        "expanded = (mb[None]+jnp.zeros((10,1,1,1))).reshape(mb.shape[0]*10,*mb.shape[1:])#[:,slc]\n",
        "predictions = samplers.stochastic_sample(diff,inpainting_scores2(diff,score_fn,expanded[:,slc],slc,scale=300.),key,expanded.shape,N=2000,traj=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "92hg5RnQW6jn"
      },
      "outputs": [],
      "source": [
        "# Baseline: integrate each conditioning trajectory's initial state after a 1e-3 Gaussian\n",
        "# perturbation (unseeded np.random, so the result differs between runs).\n",
        "z_pert = vmap(ds.integrate,(0,None),0)(mb[:,0]+1e-3*np.random.randn(*mb[:,0].shape),T_long)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3qOJfbZFYLM9"
      },
      "outputs": [],
      "source": [
        "# Undo the tiling: (10 copies, 30 trajectories, T, C). The bands are the 10th/90th\n",
        "# percentiles across copies of the state averaged over its components.\n",
        "preds = predictions.reshape(10,-1,*predictions.shape[1:])\n",
        "lower = np.percentile(preds.mean(-1),10,axis=0)\n",
        "upper = np.percentile(preds.mean(-1),90,axis=0)\n",
        "# Ground truth vs. model percentile band for the first 11 trajectories, one figure each.\n",
        "# NOTE(review): the y-label says 'State sum' but `.mean(-1)` averages the components.\n",
        "for i in range(mb.shape[0]):\n",
        "  if i\u003e10: break\n",
        "  plt.plot(T_long,mb[i].mean(-1))\n",
        "  #plt.plot(T_long,z_pert[i].mean(-1))\n",
        "  plt.fill_between(T_long,lower[i],upper[i],alpha=.3,color='y')\n",
        "  plt.plot()\n",
        "  #plt.yscale('log')\n",
        "  plt.xlabel('Time')\n",
        "  plt.ylabel('State sum')\n",
        "  plt.legend(['Ground Truth','Model 10-90 percentiles'])\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nC9BdYcQVcL1"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-Ps8bGfzWhX4"
      },
      "outputs": [],
      "source": [
        "# Sanity check: presumably (number of trajectories, time steps).\n",
        "lower.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0h19RNwvR4is"
      },
      "outputs": [],
      "source": [
        "from jax import jit,vmap,random\n",
        "\n",
        "\n",
        "@jit\n",
        "def rel_err(z1,z2):\n",
        "  \"\"\"Symmetric relative L1 error over the last axis; redefines the global `rel_err`.\"\"\"\n",
        "  return jnp.abs((jnp.abs(z1-z2)).sum(-1)/(jnp.abs(z1).sum(-1)+jnp.abs(z2).sum(-1)))\n",
        "\n",
        "# Sweep the guidance strength `scale` of inpainting_scores2 and plot, over time, the\n",
        "# geometric-mean relative error of the 30 completions with a multiplicative one-std band.\n",
        "# Errors are clamped at 1e-3 from below before taking logs.\n",
        "gt = x[:30]\n",
        "#for pred in [conditioned_samples[-1]]:#,conditioned_sample]:\n",
        "for scale in [10.,100.,300.,1000.,3000.]:\n",
        "  pred = samplers.stochastic_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc,scale=scale),key,mb.shape,N=2000,traj=False)\n",
        "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,gt),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T_long,rel_errs,label=f\"r={scale}\")\n",
        "  plt.fill_between(T_long, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend()#//['SDE completion','ODE completion'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IUMkKWJuAy2y"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "# Final state of the conditioned SDE trajectory for sample `i` vs. ground truth, with the\n",
        "# observed (conditioning) segment highlighted; first state component. Valid i: 0-29.\n",
        "i=1 # @param {type:\"slider\", min:0, max:30, step:1}\n",
        "#plt.plot(T_long,conditioned_samples[-600::100,i,:,0].T,zorder=0,alpha=.2)\n",
        "plt.plot(T_long,conditioned_samples[-1,i,  :,0].T,zorder=2,label='model')\n",
        "plt.plot(T_long,x[i,:,0],label='gt',alpha=1,zorder=99)\n",
        "plt.plot(T_long[slc],x[i,slc,0],label='cond',alpha=1,zorder=100,lw=3)\n",
        "\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "#plt.ylim(-3,3)\n",
        "plt.legend()\n",
        "#plt.legend([r'GT',r'Model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "21fgx9UbckBK"
      },
      "outputs": [],
      "source": [
        "# Same conditioning as above, but sampled with `samplers.sample` (ODE) for comparison.\n",
        "conditioned_sample = samplers.sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xKR36dO6FgwA"
      },
      "outputs": [],
      "source": [
        "from jax import jit,vmap,random\n",
        "\n",
        "@jit\n",
        "def rel_err(z1,z2):\n",
        "  \"\"\"Symmetric relative L1 error |z1-z2| / (|z1| + |z2|) over the last axis.\n",
        "\n",
        "  Redefines the global `rel_err` (used by later cells), so it must match the earlier\n",
        "  definitions: the denominator is the SUM of the norms, not their product.\n",
        "  \"\"\"\n",
        "  return jnp.abs(z1-z2).sum(-1)/(jnp.abs(z1).sum(-1)+jnp.abs(z2).sum(-1))\n",
        "\n",
        "# Geometric-mean relative error over time (multiplicative one-std band) of the final\n",
        "# conditioned SDE completions against the ground truth.\n",
        "gt = x[:30]\n",
        "for pred in [conditioned_samples[-1]]:#,conditioned_sample]:\n",
        "  clamped_errs = jax.lax.clamp(1e-5,rel_err(pred,gt),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T_long,rel_errs)\n",
        "  plt.fill_between(T_long, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylim(1e-4,1)\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend(['SDE completion'])#,'ODE completion'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Xfy7NxAfQ8S"
      },
      "outputs": [],
      "source": [
        "# ODE- vs. SDE-sampled completions of trajectory `i` (second state component), with the\n",
        "# ground truth and the observed conditioning segment.\n",
        "i=11 # @param {type:\"slider\", min:0, max:29, step:1}\n",
        "plt.plot(T_long,x[i,:,1])\n",
        "plt.plot(T_long[slc],x[i,slc,1],lw=3)\n",
        "plt.plot(T_long,conditioned_sample[i,:,1])\n",
        "plt.plot(T_long,conditioned_samples[-1,i,:,1])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'GT','Conditioning', 'ODE rollout','SDE rollout'])\n",
        "#plt.ylim(-3,3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mMvSdxUUG9Op"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mc0IHMpaDrMR"
      },
      "outputs": [],
      "source": []
    },
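    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fAyaiDRrMDd5"
      },
      "source": [
        "## Unconditional prediction quality\n",
        "\n",
        "Measure how closely unconditional samples follow the true dynamics when a sampled state is integrated forward with the ODE and compared with the rest of the sample."
      ]
    },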
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ec8x-pzpwIUg"
      },
      "outputs": [],
      "source": [
        "# Geometric-mean rollout error (`pmetric`, lower is better) of the ODE and SDE samples.\n",
        "# stoch_samples = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=1000,traj=False)\n",
        "# det_samples = samplers.sample(diff,score_fn,key,x[:30].shape)\n",
        "print(f'ODE performance {pmetric(det_samples)[0]}')\n",
        "print(f'SDE performance {pmetric(stoch_samples)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fg2JsQacJrBV"
      },
      "outputs": [],
      "source": [
        "# NOTE: `key` is reset here; later cells that sample with `key` use this value.\n",
        "from jax import random\n",
        "key = random.PRNGKey(45)\n",
        "#s=s2#,history = samplers.stochastic_sampler(denoiser,params,key,(32,)+data.shape[1:],N=500,smin=sigma_min,smax=sigma_max)\n",
        "s = stoch_samples#energy_samples_det#stoch_samples\n",
        "\n",
        "# Drop the first k sampled steps, then integrate the sample's state at step k with the true\n",
        "# dynamics (z_gts), from a 1e-3-perturbed copy (z_pert, unseeded), and from fresh random\n",
        "# initial conditions (z_random). `q` is an unused alias of `z`.\n",
        "k = 5\n",
        "z = q = s[:,k:]\n",
        "T = T_long[k:]\n",
        "z0 = z[:,0]\n",
        "z_gts = vmap(ds.integrate,(0,None),0)(z0,T)\n",
        "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape),T)\n",
        "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b-DH4XmZMwWX"
      },
      "outputs": [],
      "source": [
        "# Geometric-mean relative error to z_gts over time, with a multiplicative one-std band.\n",
        "# NOTE: uses whichever `rel_err` was defined last in the kernel.\n",
        "for pred in [z,z_pert,z_random]:\n",
        "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T,rel_errs)\n",
        "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
      ]
    },
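    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5tQ0FJA2M94F"
      },
      "source": [
        "## Sampled vs. ground-truth trajectories\n",
        "\n",
        "Visual comparison of model samples with the true rollouts started from the same state."
      ]
    },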
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oxMDdMV1CiOd"
      },
      "outputs": [],
      "source": [
        "# Animate sampled trajectory 1 (after dropping the first k steps).\n",
        "ds.animate(z[1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FgjBdB5aDkz2"
      },
      "outputs": [],
      "source": [
        "# Animate the true rollout from the same starting state, for comparison.\n",
        "ds.animate(z_gts[1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SHKOgFXNJxY6"
      },
      "outputs": [],
      "source": [
        "# One figure per trajectory (first 10): true rollout, model sample and perturbed rollout,\n",
        "# first state component.\n",
        "for i in range(10):\n",
        "  fig = plt.figure()\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  line1, = ax.plot(T,z_gts[i,:,0])\n",
        "  line2, = ax.plot(T,z[i,:,0])\n",
        "  line3, = ax.plot(T,z_pert[i,:,0])\n",
        "  plt.xlabel('Time t')\n",
        "  plt.ylabel(r'State')\n",
        "  plt.legend(['gt','model','pert'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KEMilLlTJyzM"
      },
      "outputs": [],
      "source": [
        "# One figure per trajectory (first 20): first and last state components of the true\n",
        "# rollout vs. the model sample.\n",
        "for i in range(20):\n",
        "  fig = plt.figure()\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  line1, = ax.plot(T,z_gts[i,:,0])\n",
        "  line2, = ax.plot(T,z[i,:,0])\n",
        "  line3, = ax.plot(T,z_gts[i,:,-1])\n",
        "  line5, = ax.plot(T,z[i,:,-1])\n",
        "  plt.xlabel('Time t')\n",
        "  plt.ylabel(r'State')\n",
        "  plt.legend([r'x gt',r'x model',r'z gt', r'z model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1tEXOhfcMmho"
      },
      "outputs": [],
      "source": [
        "# Sensitivity of the rollout metric to the number of SDE sampler steps N (30 samples each).\n",
        "# `pmetric` returns a multiplicative standard-error factor, hence the divide/multiply band.\n",
        "# NOTE(review): relies on stochastic_sample's default `traj` returning final samples only -- confirm.\n",
        "metric_vals =[]\n",
        "metric_stds = []\n",
        "Ns = [25,50,100,200,500,1000,2000]\n",
        "for N in Ns:\n",
        "  s = samplers.stochastic_sample(diff,score_fn,key,x[:30].shape,N=N)\n",
        "  mean,std = pmetric(s)\n",
        "  metric_vals.append(mean)\n",
        "  metric_stds.append(std)\n",
        "metric_vals = np.array(metric_vals)\n",
        "metric_stds = np.array(metric_stds)\n",
        "\n",
        "plt.plot(Ns,metric_vals)\n",
        "plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)\n",
        "plt.xlabel('Sampler steps')\n",
        "plt.ylabel('Pmetric value')\n",
        "plt.xscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qt2vnvKFoBuo"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x-998NV2Jqsq"
      },
      "outputs": [],
      "source": [
        "# Compare model samples with held-out test trajectories, a 1e-3-perturbed rollout of the\n",
        "# test initial states, and rollouts from random initial conditions.\n",
        "# Uses the same 60-step grid as the data (`T_long`, not the full `ds.T_long`). Only as many\n",
        "# test trajectories are used as there are samples (stoch_samples holds 30 after the\n",
        "# sample-generation cell), so all arrays share the same shape.\n",
        "# NOTE(review): the samples are unconditional and not paired with specific test\n",
        "# trajectories, so the first curve is closer to a random-init baseline than a rollout error.\n",
        "T = T_long\n",
        "z = stoch_samples\n",
        "z_gts = test_x[:z.shape[0]]\n",
        "z0 = z_gts[:,0]\n",
        "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape),T)\n",
        "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n",
        "for pred in [z,z_pert,z_random]:\n",
        "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T,rel_errs)\n",
        "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
        "kind": "private"
      },
      "name": "lorenz.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
